{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementing DQN variants\n",
    "\n",
    "We just learned how to implement DQN using stable baselines. Now, let's see how to\n",
    "implement the variants of DQN such as double DQN, DQN with prioritized experience\n",
    "replay and dueling DQN. Implementing DQN variants is very simple with the baselines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines import DQN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we are dealing with Atari games we can use a convolutional neural network instead\n",
    "of a vanilla neural network. So, we use CnnPolicy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines.deepq.policies import CnnPolicy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We learned that we preprocess the game screen before feeding it to the agent. With\n",
    "baselines, we don't have to preprocess manually, instead, we can make use of make_atari\n",
    "module which takes care of preprocessing the game screen:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines.common.atari_wrappers import make_atari"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's create an Atari game environment. Let's create the Ice Hockey game\n",
    "environment:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = make_atari('IceHockeyNoFrameskip-v4')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we define our keyword arguments as shown below: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kwargs = {\"double_q\": True, \"prioritized_replay\": True, \"policy_kwargs\": dict(dueling=True)}"
   ]
  },
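  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a refresher on what these three options switch on, here is a minimal plain-Python\n",
    "sketch of the ideas behind them. It is an illustration under simplifying assumptions,\n",
    "not how Stable Baselines implements them: the function names (`double_q_target`,\n",
    "`dueling_q_values`, `per_probabilities`), the `alpha` and `eps` values, and all the\n",
    "numbers are hypothetical and chosen only to exercise the formulas."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only; names and numbers are hypothetical.\n",
    "\n",
    "def double_q_target(reward, done, gamma, q_online_next, q_target_next):\n",
    "    # Double DQN: the online network selects the next action,\n",
    "    # the target network evaluates it.\n",
    "    if done:\n",
    "        return reward\n",
    "    best_action = max(range(len(q_online_next)), key=lambda a: q_online_next[a])\n",
    "    return reward + gamma * q_target_next[best_action]\n",
    "\n",
    "\n",
    "def dueling_q_values(state_value, advantages):\n",
    "    # Dueling head: Q(s, a) = V(s) + A(s, a) - mean over actions of A(s, .)\n",
    "    mean_adv = sum(advantages) / len(advantages)\n",
    "    return [state_value + adv - mean_adv for adv in advantages]\n",
    "\n",
    "\n",
    "def per_probabilities(td_errors, alpha=0.6, eps=1e-6):\n",
    "    # Proportional prioritized replay: the sampling probability of a\n",
    "    # transition is proportional to (|TD error| + eps) ** alpha.\n",
    "    scaled = [(abs(err) + eps) ** alpha for err in td_errors]\n",
    "    total = sum(scaled)\n",
    "    return [s / total for s in scaled]\n",
    "\n",
    "\n",
    "target = double_q_target(1.0, False, 0.99,\n",
    "                         q_online_next=[0.2, 0.9, 0.5],\n",
    "                         q_target_next=[0.3, 0.4, 0.8])\n",
    "q_values = dueling_q_values(2.0, [1.0, -1.0, 0.0])\n",
    "probs = per_probabilities([0.1, 2.0, 0.5])\n",
    "print(target, q_values, probs)"
   ]
  },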
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, while instantiating our agent, we just need to pass the keyword arguments: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent = DQN(CnnPolicy, env, verbose=1, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the agent:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent.learn(total_timesteps=25000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That's it! Now we have the dueling double DQN with prioritized experience replay. \n",
    "\n",
    "After training the agent, we can have a look at how our trained agent performs in the\n",
    "environment:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state = env.reset()\n",
    "while True:\n",
    "    action, _ = agent.predict(state)\n",
    "    next_state, reward, done, info = env.step(action)\n",
    "    state = next_state\n",
    "    env.render()"
   ]
  }
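  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Watching the agent is useful, but a number is easier to compare across runs. One\n",
    "option is to play a fixed number of episodes and record the total reward of each.\n",
    "The sketch below is a rough illustration: `StubEnv` and `StubAgent` are hypothetical\n",
    "stand-ins that only mimic the `reset`, `step`, and `predict` calls used in this\n",
    "notebook, so the cell runs without Atari. With the trained agent, the same helper\n",
    "could presumably be called as `run_episodes(agent, env, n_episodes=5)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only; StubEnv and StubAgent are hypothetical stand-ins.\n",
    "import random\n",
    "\n",
    "\n",
    "class StubEnv:\n",
    "    # Each episode lasts a fixed number of steps; every step pays reward 1.0.\n",
    "    def __init__(self, episode_length=5):\n",
    "        self.episode_length = episode_length\n",
    "        self.steps = 0\n",
    "\n",
    "    def reset(self):\n",
    "        self.steps = 0\n",
    "        return 0\n",
    "\n",
    "    def step(self, action):\n",
    "        self.steps += 1\n",
    "        done = self.steps >= self.episode_length\n",
    "        return self.steps, 1.0, done, {}\n",
    "\n",
    "\n",
    "class StubAgent:\n",
    "    # Returns an (action, hidden_state) pair, like agent.predict in this notebook.\n",
    "    def predict(self, state):\n",
    "        return random.randrange(2), None\n",
    "\n",
    "\n",
    "def run_episodes(agent, env, n_episodes):\n",
    "    # Play n_episodes to completion and return the total reward of each one.\n",
    "    returns = []\n",
    "    for episode in range(n_episodes):\n",
    "        state, done, total = env.reset(), False, 0.0\n",
    "        while not done:\n",
    "            action, _hidden = agent.predict(state)\n",
    "            state, reward, done, info = env.step(action)\n",
    "            total += reward\n",
    "        returns.append(total)\n",
    "    return returns\n",
    "\n",
    "\n",
    "episode_returns = run_episodes(StubAgent(), StubEnv(), n_episodes=3)\n",
    "print(episode_returns)"
   ]
  }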
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
